package com.example.aspect;

import com.example.annotation.RateLimit;
import com.example.config.RateLimitService;
import com.example.exception.BusinessException;
import com.example.util.SecurityUtils;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

/**
 * Rate-limiting aspect.
 *
 * <p>Intercepts methods annotated with {@link RateLimit}, derives a limiter key from the
 * annotation's {@code key()} and {@code type()}, and asks {@link RateLimitService} whether the
 * call is allowed for the annotation's {@code limit()} / {@code window()}. Rejected calls fail
 * with a {@link BusinessException} carrying the code {@code RATE_LIMIT_EXCEEDED} and the
 * annotation's {@code message()}.
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class RateLimitAspect {

    /** Placeholder used when no valid client IP can be resolved, and the sentinel proxies emit. */
    private static final String UNKNOWN = "unknown";

    /**
     * Headers inspected, in priority order, for the originating client IP before falling back to
     * the socket peer address.
     */
    private static final String[] CLIENT_IP_HEADERS = {
            "X-Forwarded-For",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    private final RateLimitService rateLimitService;
    private final SecurityUtils securityUtils;

    /**
     * Enforces the {@link RateLimit} policy before letting the intercepted call proceed.
     *
     * @param point     the intercepted join point
     * @param rateLimit the annotation instance bound from the pointcut
     * @return whatever the intercepted method returns
     * @throws BusinessException        with code {@code RATE_LIMIT_EXCEEDED} when the service
     *                                  rejects the call
     * @throws IllegalArgumentException when {@code type()} is CUSTOM but {@code key()} is blank
     * @throws Throwable                anything thrown by the intercepted method itself
     */
    @Around("@annotation(rateLimit)")
    public Object around(ProceedingJoinPoint point, RateLimit rateLimit) throws Throwable {
        String key = generateKey(point, rateLimit);

        if (!rateLimitService.isAllowed(key, rateLimit.limit(), rateLimit.window())) {
            log.warn("限流触发: key={}, limit={}, window={}", key, rateLimit.limit(), rateLimit.window());
            throw new BusinessException("RATE_LIMIT_EXCEEDED", rateLimit.message());
        }

        return point.proceed();
    }

    /**
     * Builds the limiter key for one invocation.
     *
     * <p>Format: {@code <scope>:<subject>}. {@code scope} is the annotation's explicit
     * {@code key()} when set. Otherwise it is the intercepted method's
     * {@code DeclaringType.method} name, so that key-less limits on different endpoints (which may
     * use different limits/windows) do not share and exhaust one counter. {@code subject}
     * depends on {@code type()}:
     * <ul>
     *   <li>USER: {@code user:<username>}, or {@code ip:<ip>} when no authenticated user</li>
     *   <li>IP: {@code ip:<ip>}</li>
     *   <li>GLOBAL: {@code global}</li>
     *   <li>CUSTOM: nothing; the explicit key alone identifies the bucket</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code type()} is CUSTOM and {@code key()} is blank
     * @throws IllegalStateException    if {@code type()} is a value this method does not handle
     */
    private String generateKey(ProceedingJoinPoint point, RateLimit rateLimit) {
        StringBuilder keyBuilder = new StringBuilder();

        if (StringUtils.hasText(rateLimit.key())) {
            // An explicit key names an intentional (possibly shared) bucket: use it verbatim.
            keyBuilder.append(rateLimit.key()).append(":");
        } else {
            // No explicit key: scope the bucket to this method so unrelated endpoints don't collide.
            keyBuilder.append(point.getSignature().getDeclaringTypeName())
                    .append(".")
                    .append(point.getSignature().getName())
                    .append(":");
        }

        switch (rateLimit.type()) {
            case USER: {
                String username = getCurrentUsername();
                if (username != null) {
                    keyBuilder.append("user:").append(username);
                } else {
                    // No authenticated user: degrade to IP-based limiting.
                    keyBuilder.append("ip:").append(getClientIp());
                }
                break;
            }
            case IP:
                keyBuilder.append("ip:").append(getClientIp());
                break;
            case GLOBAL:
                keyBuilder.append("global");
                break;
            case CUSTOM:
                if (!StringUtils.hasText(rateLimit.key())) {
                    throw new IllegalArgumentException("自定义限流类型必须指定key");
                }
                // The explicit key was already appended above; nothing else to add.
                break;
            default:
                // Fail loudly rather than silently merging all such calls into one bucket.
                throw new IllegalStateException("Unsupported rate limit type: " + rateLimit.type());
        }

        return keyBuilder.toString();
    }

    /**
     * Returns the name of the authenticated principal in the current security context.
     *
     * @return the username, or {@code null} when there is no authentication, it is not
     *         authenticated, the principal is the string {@code "anonymousUser"}, or lookup fails
     */
    private String getCurrentUsername() {
        try {
            Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
            if (authentication != null && authentication.isAuthenticated() &&
                !"anonymousUser".equals(authentication.getPrincipal())) {
                return authentication.getName();
            }
        } catch (Exception e) {
            // Best effort: callers fall back to IP-based limiting when no user is found.
            log.debug("获取当前用户名失败: {}", e.getMessage());
        }
        return null;
    }

    /**
     * Returns the client IP of the request bound to the current thread.
     *
     * @return a validated IP, or {@code "unknown"} when no request is bound or lookup fails
     */
    private String getClientIp() {
        try {
            ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            if (attributes != null) {
                HttpServletRequest request = attributes.getRequest();
                return getClientIpFromRequest(request);
            }
        } catch (Exception e) {
            // Best effort: e.g. a non-servlet request context makes the cast above fail.
            log.debug("获取客户端IP失败: {}", e.getMessage());
        }
        return UNKNOWN;
    }

    /**
     * Resolves the client IP from proxy headers, falling back to the socket peer address.
     *
     * <p>Each header in {@link #CLIENT_IP_HEADERS} is tried in order. For comma-separated lists
     * such as {@code X-Forwarded-For: client, proxy1, proxy2}, only the first entry is considered.
     * A candidate is accepted only if {@code securityUtils.isValidIpAddress} approves it. A
     * malformed value falls through to the next source, ending with
     * {@link HttpServletRequest#getRemoteAddr()}, instead of collapsing the client into the shared
     * {@code "unknown"} bucket.
     *
     * <p>NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them.
     * A caller that is not behind such a proxy can rotate spoofed values to evade IP-based limits.
     * Consider honoring them only for requests arriving from known proxy addresses.
     *
     * @return a validated IP, or {@code "unknown"} when no source yields one
     */
    private String getClientIpFromRequest(HttpServletRequest request) {
        for (String header : CLIENT_IP_HEADERS) {
            String candidate = firstListedIp(request.getHeader(header));
            if (isUsableIp(candidate)) {
                return candidate;
            }
        }

        String remoteAddr = firstListedIp(request.getRemoteAddr());
        return isUsableIp(remoteAddr) ? remoteAddr : UNKNOWN;
    }

    /**
     * Extracts the first entry of a possibly comma-separated address list.
     *
     * @return the trimmed first entry, or {@code null} when the value is blank or the entry is
     *         empty or the proxy placeholder {@code "unknown"} (case-insensitive)
     */
    private String firstListedIp(String value) {
        if (!StringUtils.hasText(value)) {
            return null;
        }
        int comma = value.indexOf(',');
        String first = (comma >= 0 ? value.substring(0, comma) : value).trim();
        if (first.isEmpty() || UNKNOWN.equalsIgnoreCase(first)) {
            return null;
        }
        return first;
    }

    /** Returns whether {@code ip} is non-null and passes {@code securityUtils.isValidIpAddress}. */
    private boolean isUsableIp(String ip) {
        return ip != null && securityUtils.isValidIpAddress(ip);
    }
}
